home *** CD-ROM | disk | FTP | other *** search
/ Cream of the Crop 26 / Cream of the Crop 26.iso / os2 / octa209s.zip / octave-2.09 / liboctave / CDiagMatrix.cc < prev    next >
C/C++ Source or Header  |  1996-10-12  |  18KB  |  867 lines

  1. // DiagMatrix manipulations.
  2. /*
  3.  
  4. Copyright (C) 1996 John W. Eaton
  5.  
  6. This file is part of Octave.
  7.  
  8. Octave is free software; you can redistribute it and/or modify it
  9. under the terms of the GNU General Public License as published by the
  10. Free Software Foundation; either version 2, or (at your option) any
  11. later version.
  12.  
  13. Octave is distributed in the hope that it will be useful, but WITHOUT
  14. ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  15. FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
  16. for more details.
  17.  
  18. You should have received a copy of the GNU General Public License
  19. along with Octave; see the file COPYING.  If not, write to the Free
  20. Software Foundation, 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
  21.  
  22. */
  23.  
  24. #if defined (__GNUG__)
  25. #pragma implementation
  26. #endif
  27.  
  28. #ifdef HAVE_CONFIG_H
  29. #include <config.h>
  30. #endif
  31.  
  32. #include <iostream.h>
  33.  
  34. #include "lo-error.h"
  35. #include "mx-base.h"
  36. #include "mx-inlines.cc"
  37. #include "oct-cmplx.h"
  38.  
  39. // Complex Diagonal Matrix class
  40.  
  41. ComplexDiagMatrix::ComplexDiagMatrix (const DiagMatrix& a)
  42.   : MDiagArray2<Complex> (a.rows (), a.cols ())
  43. {
  44.   for (int i = 0; i < length (); i++)
  45.     elem (i, i) = a.elem (i, i);
  46. }
  47.  
  48. bool
  49. ComplexDiagMatrix::operator == (const ComplexDiagMatrix& a) const
  50. {
  51.   if (rows () != a.rows () || cols () != a.cols ())
  52.     return 0;
  53.  
  54.   return equal (data (), a.data (), length ());
  55. }
  56.  
  57. bool
  58. ComplexDiagMatrix::operator != (const ComplexDiagMatrix& a) const
  59. {
  60.   return !(*this == a);
  61. }
  62.  
  63. ComplexDiagMatrix&
  64. ComplexDiagMatrix::fill (double val)
  65. {
  66.   for (int i = 0; i < length (); i++)
  67.     elem (i, i) = val;
  68.   return *this;
  69. }
  70.  
  71. ComplexDiagMatrix&
  72. ComplexDiagMatrix::fill (const Complex& val)
  73. {
  74.   for (int i = 0; i < length (); i++)
  75.     elem (i, i) = val;
  76.   return *this;
  77. }
  78.  
  79. ComplexDiagMatrix&
  80. ComplexDiagMatrix::fill (double val, int beg, int end)
  81. {
  82.   if (beg < 0 || end >= length () || end < beg)
  83.     {
  84.       (*current_liboctave_error_handler) ("range error for fill");
  85.       return *this;
  86.     }
  87.  
  88.   for (int i = beg; i <= end; i++)
  89.     elem (i, i) = val;
  90.  
  91.   return *this;
  92. }
  93.  
  94. ComplexDiagMatrix&
  95. ComplexDiagMatrix::fill (const Complex& val, int beg, int end)
  96. {
  97.   if (beg < 0 || end >= length () || end < beg)
  98.     {
  99.       (*current_liboctave_error_handler) ("range error for fill");
  100.       return *this;
  101.     }
  102.  
  103.   for (int i = beg; i <= end; i++)
  104.     elem (i, i) = val;
  105.  
  106.   return *this;
  107. }
  108.  
  109. ComplexDiagMatrix&
  110. ComplexDiagMatrix::fill (const ColumnVector& a)
  111. {
  112.   int len = length ();
  113.   if (a.length () != len)
  114.     {
  115.       (*current_liboctave_error_handler) ("range error for fill");
  116.       return *this;
  117.     }
  118.  
  119.   for (int i = 0; i < len; i++)
  120.     elem (i, i) = a.elem (i);
  121.  
  122.   return *this;
  123. }
  124.  
  125. ComplexDiagMatrix&
  126. ComplexDiagMatrix::fill (const ComplexColumnVector& a)
  127. {
  128.   int len = length ();
  129.   if (a.length () != len)
  130.     {
  131.       (*current_liboctave_error_handler) ("range error for fill");
  132.       return *this;
  133.     }
  134.  
  135.   for (int i = 0; i < len; i++)
  136.     elem (i, i) = a.elem (i);
  137.  
  138.   return *this;
  139. }
  140.  
  141. ComplexDiagMatrix&
  142. ComplexDiagMatrix::fill (const RowVector& a)
  143. {
  144.   int len = length ();
  145.   if (a.length () != len)
  146.     {
  147.       (*current_liboctave_error_handler) ("range error for fill");
  148.       return *this;
  149.     }
  150.  
  151.   for (int i = 0; i < len; i++)
  152.     elem (i, i) = a.elem (i);
  153.  
  154.   return *this;
  155. }
  156.  
  157. ComplexDiagMatrix&
  158. ComplexDiagMatrix::fill (const ComplexRowVector& a)
  159. {
  160.   int len = length ();
  161.   if (a.length () != len)
  162.     {
  163.       (*current_liboctave_error_handler) ("range error for fill");
  164.       return *this;
  165.     }
  166.  
  167.   for (int i = 0; i < len; i++)
  168.     elem (i, i) = a.elem (i);
  169.  
  170.   return *this;
  171. }
  172.  
  173. ComplexDiagMatrix&
  174. ComplexDiagMatrix::fill (const ColumnVector& a, int beg)
  175. {
  176.   int a_len = a.length ();
  177.   if (beg < 0 || beg + a_len >= length ())
  178.     {
  179.       (*current_liboctave_error_handler) ("range error for fill");
  180.       return *this;
  181.     }
  182.  
  183.   for (int i = 0; i < a_len; i++)
  184.     elem (i+beg, i+beg) = a.elem (i);
  185.  
  186.   return *this;
  187. }
  188.  
  189. ComplexDiagMatrix&
  190. ComplexDiagMatrix::fill (const ComplexColumnVector& a, int beg)
  191. {
  192.   int a_len = a.length ();
  193.   if (beg < 0 || beg + a_len >= length ())
  194.     {
  195.       (*current_liboctave_error_handler) ("range error for fill");
  196.       return *this;
  197.     }
  198.  
  199.   for (int i = 0; i < a_len; i++)
  200.     elem (i+beg, i+beg) = a.elem (i);
  201.  
  202.   return *this;
  203. }
  204.  
  205. ComplexDiagMatrix&
  206. ComplexDiagMatrix::fill (const RowVector& a, int beg)
  207. {
  208.   int a_len = a.length ();
  209.   if (beg < 0 || beg + a_len >= length ())
  210.     {
  211.       (*current_liboctave_error_handler) ("range error for fill");
  212.       return *this;
  213.     }
  214.  
  215.   for (int i = 0; i < a_len; i++)
  216.     elem (i+beg, i+beg) = a.elem (i);
  217.  
  218.   return *this;
  219. }
  220.  
  221. ComplexDiagMatrix&
  222. ComplexDiagMatrix::fill (const ComplexRowVector& a, int beg)
  223. {
  224.   int a_len = a.length ();
  225.   if (beg < 0 || beg + a_len >= length ())
  226.     {
  227.       (*current_liboctave_error_handler) ("range error for fill");
  228.       return *this;
  229.     }
  230.  
  231.   for (int i = 0; i < a_len; i++)
  232.     elem (i+beg, i+beg) = a.elem (i);
  233.  
  234.   return *this;
  235. }
  236.  
  237. ComplexDiagMatrix
  238. ComplexDiagMatrix::hermitian (void) const
  239. {
  240.   return ComplexDiagMatrix (conj_dup (data (), length ()), cols (), rows ());
  241. }
  242.  
  243. ComplexDiagMatrix
  244. ComplexDiagMatrix::transpose (void) const
  245. {
  246.   return ComplexDiagMatrix (dup (data (), length ()), cols (), rows ());
  247. }
  248.  
  249. ComplexDiagMatrix
  250. conj (const ComplexDiagMatrix& a)
  251. {
  252.   ComplexDiagMatrix retval;
  253.   int a_len = a.length ();
  254.   if (a_len > 0)
  255.     retval = ComplexDiagMatrix (conj_dup (a.data (), a_len),
  256.                 a.rows (), a.cols ());
  257.   return retval;
  258. }
  259.  
  260. // resize is the destructive analog for this one
  261.  
  262. ComplexMatrix
  263. ComplexDiagMatrix::extract (int r1, int c1, int r2, int c2) const
  264. {
  265.   if (r1 > r2) { int tmp = r1; r1 = r2; r2 = tmp; }
  266.   if (c1 > c2) { int tmp = c1; c1 = c2; c2 = tmp; }
  267.  
  268.   int new_r = r2 - r1 + 1;
  269.   int new_c = c2 - c1 + 1;
  270.  
  271.   ComplexMatrix result (new_r, new_c);
  272.  
  273.   for (int j = 0; j < new_c; j++)
  274.     for (int i = 0; i < new_r; i++)
  275.       result.elem (i, j) = elem (r1+i, c1+j);
  276.  
  277.   return result;
  278. }
  279.  
  280. // extract row or column i.
  281.  
  282. ComplexRowVector
  283. ComplexDiagMatrix::row (int i) const
  284. {
  285.   int nr = rows ();
  286.   int nc = cols ();
  287.   if (i < 0 || i >= nr)
  288.     {
  289.       (*current_liboctave_error_handler) ("invalid row selection");
  290.       return RowVector (); 
  291.     }
  292.  
  293.   ComplexRowVector retval (nc, 0.0);
  294.   if (nr <= nc || (nr > nc && i < nc))
  295.     retval.elem (i) = elem (i, i);
  296.  
  297.   return retval;
  298. }
  299.  
  300. ComplexRowVector
  301. ComplexDiagMatrix::row (char *s) const
  302. {
  303.   if (! s)
  304.     {
  305.       (*current_liboctave_error_handler) ("invalid row selection");
  306.       return ComplexRowVector (); 
  307.     }
  308.  
  309.   char c = *s;
  310.   if (c == 'f' || c == 'F')
  311.     return row (0);
  312.   else if (c == 'l' || c == 'L')
  313.     return row (rows () - 1);
  314.   else
  315.     {
  316.       (*current_liboctave_error_handler) ("invalid row selection");
  317.       return ComplexRowVector ();
  318.     }
  319. }
  320.  
  321. ComplexColumnVector
  322. ComplexDiagMatrix::column (int i) const
  323. {
  324.   int nr = rows ();
  325.   int nc = cols ();
  326.   if (i < 0 || i >= nc)
  327.     {
  328.       (*current_liboctave_error_handler) ("invalid column selection");
  329.       return ColumnVector (); 
  330.     }
  331.  
  332.   ComplexColumnVector retval (nr, 0.0);
  333.   if (nr >= nc || (nr < nc && i < nr))
  334.     retval.elem (i) = elem (i, i);
  335.  
  336.   return retval;
  337. }
  338.  
  339. ComplexColumnVector
  340. ComplexDiagMatrix::column (char *s) const
  341. {
  342.   if (! s)
  343.     {
  344.       (*current_liboctave_error_handler) ("invalid column selection");
  345.       return ColumnVector (); 
  346.     }
  347.  
  348.   char c = *s;
  349.   if (c == 'f' || c == 'F')
  350.     return column (0);
  351.   else if (c == 'l' || c == 'L')
  352.     return column (cols () - 1);
  353.   else
  354.     {
  355.       (*current_liboctave_error_handler) ("invalid column selection");
  356.       return ColumnVector (); 
  357.     }
  358. }
  359.  
  360. ComplexDiagMatrix
  361. ComplexDiagMatrix::inverse (void) const
  362. {
  363.   int info;
  364.   return inverse (info);
  365. }
  366.  
  367. ComplexDiagMatrix
  368. ComplexDiagMatrix::inverse (int& info) const
  369. {
  370.   int nr = rows ();
  371.   int nc = cols ();
  372.   if (nr != nc)
  373.     {
  374.       (*current_liboctave_error_handler) ("inverse requires square matrix");
  375.       return DiagMatrix ();
  376.     }
  377.  
  378.   ComplexDiagMatrix retval (nr, nc);
  379.  
  380.   info = 0;
  381.   for (int i = 0; i < length (); i++)
  382.     {
  383.       if (elem (i, i) == 0.0)
  384.     {
  385.       info = -1;
  386.       return *this;
  387.     }
  388.       else
  389.     retval.elem (i, i) = 1.0 / elem (i, i);
  390.     }
  391.  
  392.   return retval;
  393. }
  394.  
  395. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  396.  
  397. ComplexDiagMatrix&
  398. ComplexDiagMatrix::operator += (const DiagMatrix& a)
  399. {
  400.   int nr = rows ();
  401.   int nc = cols ();
  402.  
  403.   int a_nr = a.rows ();
  404.   int a_nc = a.cols ();
  405.  
  406.   if (nr != a_nr || nc != a_nc)
  407.     {
  408.       gripe_nonconformant ("operator +=", nr, nc, a_nr, a_nc);
  409.       return *this;
  410.     }
  411.  
  412.   if (nr == 0 || nc == 0)
  413.     return *this;
  414.  
  415.   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
  416.  
  417.   add2 (d, a.data (), length ());
  418.   return *this;
  419. }
  420.  
  421. ComplexDiagMatrix&
  422. ComplexDiagMatrix::operator -= (const DiagMatrix& a)
  423. {
  424.   int nr = rows ();
  425.   int nc = cols ();
  426.  
  427.   int a_nr = a.rows ();
  428.   int a_nc = a.cols ();
  429.  
  430.   if (nr != a_nr || nc != a_nc)
  431.     {
  432.       gripe_nonconformant ("operator -=", nr, nc, a_nr, a_nc);
  433.       return *this;
  434.     }
  435.  
  436.   if (nr == 0 || nc == 0)
  437.     return *this;
  438.  
  439.   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
  440.  
  441.   subtract2 (d, a.data (), length ());
  442.   return *this;
  443. }
  444.  
  445. ComplexDiagMatrix&
  446. ComplexDiagMatrix::operator += (const ComplexDiagMatrix& a)
  447. {
  448.   int nr = rows ();
  449.   int nc = cols ();
  450.  
  451.   int a_nr = a.rows ();
  452.   int a_nc = a.cols ();
  453.  
  454.   if (nr != a_nr || nc != a_nc)
  455.     {
  456.       gripe_nonconformant ("operator +=", nr, nc, a_nr, a_nc);
  457.       return *this;
  458.     }
  459.  
  460.   if (nr == 0 || nc == 0)
  461.     return *this;
  462.  
  463.   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
  464.  
  465.   add2 (d, a.data (), length ());
  466.   return *this;
  467. }
  468.  
  469. ComplexDiagMatrix&
  470. ComplexDiagMatrix::operator -= (const ComplexDiagMatrix& a)
  471. {
  472.   int nr = rows ();
  473.   int nc = cols ();
  474.  
  475.   int a_nr = a.rows ();
  476.   int a_nc = a.cols ();
  477.  
  478.   if (nr != a_nr || nc != a_nc)
  479.     {
  480.       gripe_nonconformant ("operator -=", nr, nc, a_nr, a_nc);
  481.       return *this;
  482.     }
  483.  
  484.   if (nr == 0 || nc == 0)
  485.     return *this;
  486.  
  487.   Complex *d = fortran_vec (); // Ensures only one reference to my privates!
  488.  
  489.   subtract2 (d, a.data (), length ());
  490.   return *this;
  491. }
  492.  
  493. // diagonal matrix by scalar -> diagonal matrix operations
  494.  
  495. ComplexDiagMatrix
  496. operator * (const ComplexDiagMatrix& a, double s)
  497. {
  498.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  499.                 a.rows (), a.cols ());
  500. }
  501.  
  502. ComplexDiagMatrix
  503. operator / (const ComplexDiagMatrix& a, double s)
  504. {
  505.   return ComplexDiagMatrix (divide (a.data (), a.length (), s),
  506.                 a.rows (), a.cols ());
  507. }
  508.  
  509. ComplexDiagMatrix
  510. operator * (const DiagMatrix& a, const Complex& s)
  511. {
  512.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  513.                 a.rows (), a.cols ());
  514. }
  515.  
  516. ComplexDiagMatrix
  517. operator / (const DiagMatrix& a, const Complex& s)
  518. {
  519.   return ComplexDiagMatrix (divide (a.data (), a.length (), s),
  520.                 a.rows (), a.cols ());
  521. }
  522.  
  523. // scalar by diagonal matrix -> diagonal matrix operations
  524.  
  525. ComplexDiagMatrix
  526. operator * (double s, const ComplexDiagMatrix& a)
  527. {
  528.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  529.                 a.rows (), a.cols ());
  530. }
  531.  
  532. ComplexDiagMatrix
  533. operator * (const Complex& s, const DiagMatrix& a)
  534. {
  535.   return ComplexDiagMatrix (multiply (a.data (), a.length (), s),
  536.                 a.rows (), a.cols ());
  537. }
  538.  
  539. // diagonal matrix by diagonal matrix -> diagonal matrix operations
  540.  
  541. ComplexDiagMatrix
  542. operator * (const ComplexDiagMatrix& a, const ComplexDiagMatrix& b)
  543. {
  544.   int nr_a = a.rows ();
  545.   int nc_a = a.cols ();
  546.  
  547.   int nr_b = b.rows ();
  548.   int nc_b = b.cols ();
  549.  
  550.   if (nc_a != nr_b)
  551.     {
  552.       gripe_nonconformant ("operator *", nr_a, nc_a, nr_b, nc_b);
  553.       return ComplexDiagMatrix ();
  554.     }
  555.  
  556.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  557.     return ComplexDiagMatrix (nr_a, nc_a, 0.0);
  558.  
  559.   ComplexDiagMatrix c (nr_a, nc_b);
  560.  
  561.   int len = nr_a < nc_b ? nr_a : nc_b;
  562.  
  563.   for (int i = 0; i < len; i++)
  564.     {
  565.       Complex a_element = a.elem (i, i);
  566.       Complex b_element = b.elem (i, i);
  567.  
  568.       if (a_element == 0.0 || b_element == 0.0)
  569.         c.elem (i, i) = 0.0;
  570.       else if (a_element == 1.0)
  571.         c.elem (i, i) = b_element;
  572.       else if (b_element == 1.0)
  573.         c.elem (i, i) = a_element;
  574.       else
  575.         c.elem (i, i) = a_element * b_element;
  576.     }
  577.  
  578.   return c;
  579. }
  580.  
  581. ComplexDiagMatrix
  582. operator + (const ComplexDiagMatrix& m, const DiagMatrix& a)
  583. {
  584.   int nr = m.rows ();
  585.   int nc = m.cols ();
  586.  
  587.   int a_nr = a.rows ();
  588.   int a_nc = a.cols ();
  589.  
  590.   if (nr != a_nr || nc != a_nc)
  591.     {
  592.       gripe_nonconformant ("operator +", nr, nc, a_nr, a_nc);
  593.       return ComplexDiagMatrix ();
  594.     }
  595.  
  596.   if (nr == 0 || nc == 0)
  597.     return ComplexDiagMatrix (nr, nc);
  598.  
  599.   return ComplexDiagMatrix (add (m.data (), a.data (), m.length ()), nr, nc);
  600. }
  601.  
  602. ComplexDiagMatrix
  603. operator - (const ComplexDiagMatrix& m, const DiagMatrix& a)
  604. {
  605.   int nr = m.rows ();
  606.   int nc = m.cols ();
  607.  
  608.   int a_nr = a.rows ();
  609.   int a_nc = a.cols ();
  610.  
  611.   if (nr != a_nr || nc != a_nc)
  612.     {
  613.       gripe_nonconformant ("operator -", nr, nc, a_nr, a_nc);
  614.       return ComplexDiagMatrix ();
  615.     }
  616.  
  617.   if (nr == 0 || nc == 0)
  618.     return ComplexDiagMatrix (nr, nc);
  619.  
  620.   return ComplexDiagMatrix (subtract (m.data (), a.data (), m.length ()),
  621.                 nr, nc);
  622. }
  623.  
  624. ComplexDiagMatrix
  625. operator * (const ComplexDiagMatrix& a, const DiagMatrix& b)
  626. {
  627.   int nr_a = a.rows ();
  628.   int nc_a = a.cols ();
  629.  
  630.   int nr_b = b.rows ();
  631.   int nc_b = b.cols ();
  632.  
  633.   if (nc_a != nr_b)
  634.     {
  635.       gripe_nonconformant ("operator *", nr_a, nc_a, nr_b, nc_b);
  636.       return ComplexDiagMatrix ();
  637.     }
  638.  
  639.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  640.     return ComplexDiagMatrix (nr_a, nc_a, 0.0);
  641.  
  642.   ComplexDiagMatrix c (nr_a, nc_b);
  643.  
  644.   int len = nr_a < nc_b ? nr_a : nc_b;
  645.  
  646.   for (int i = 0; i < len; i++)
  647.     {
  648.       Complex a_element = a.elem (i, i);
  649.       double b_element = b.elem (i, i);
  650.  
  651.       if (a_element == 0.0 || b_element == 0.0)
  652.         c.elem (i, i) = 0.0;
  653.       else if (a_element == 1.0)
  654.         c.elem (i, i) = b_element;
  655.       else if (b_element == 1.0)
  656.         c.elem (i, i) = a_element;
  657.       else
  658.         c.elem (i, i) = a_element * b_element;
  659.     }
  660.  
  661.   return c;
  662. }
  663.  
  664. ComplexDiagMatrix
  665. operator + (const DiagMatrix& m, const ComplexDiagMatrix& a)
  666. {
  667.   int nr = m.rows ();
  668.   int nc = m.cols ();
  669.  
  670.   int a_nr = a.rows ();
  671.   int a_nc = a.cols ();
  672.  
  673.   if (nr != a_nr || nc != a_nc)
  674.     {
  675.       gripe_nonconformant ("operator +", nr, nc, a_nr, a_nc);
  676.       return ComplexDiagMatrix ();
  677.     }
  678.  
  679.   if (nc == 0 || nr == 0)
  680.     return ComplexDiagMatrix (nr, nc);
  681.  
  682.   return ComplexDiagMatrix (add (m.data (), a.data (), m.length ()),  nr, nc);
  683. }
  684.  
  685. ComplexDiagMatrix
  686. operator - (const DiagMatrix& m, const ComplexDiagMatrix& a)
  687. {
  688.   int nr = m.rows ();
  689.   int nc = m.cols ();
  690.  
  691.   int a_nr = a.rows ();
  692.   int a_nc = a.cols ();
  693.  
  694.   if (nr != a_nr || nc != a_nc)
  695.     {
  696.       gripe_nonconformant ("operator -", nr, nc, a_nr, a_nc);
  697.       return ComplexDiagMatrix ();
  698.     }
  699.  
  700.   if (nc == 0 || nr == 0)
  701.     return ComplexDiagMatrix (nr, nc);
  702.  
  703.   return ComplexDiagMatrix (subtract (m.data (), a.data (), m.length ()),
  704.                 nr, nc);
  705. }
  706.  
  707. ComplexDiagMatrix
  708. operator * (const DiagMatrix& a, const ComplexDiagMatrix& b)
  709. {
  710.   int nr_a = a.rows ();
  711.   int nc_a = a.cols ();
  712.  
  713.   int nr_b = b.rows ();
  714.   int nc_b = b.cols ();
  715.  
  716.   if (nc_a != nr_b)
  717.     {
  718.       gripe_nonconformant ("operator *", nr_a, nc_a, nr_b, nc_b);
  719.       return ComplexDiagMatrix ();
  720.     }
  721.  
  722.   if (nr_a == 0 || nc_a == 0 || nc_b == 0)
  723.     return ComplexDiagMatrix (nr_a, nc_a, 0.0);
  724.  
  725.   ComplexDiagMatrix c (nr_a, nc_b);
  726.  
  727.   int len = nr_a < nc_b ? nr_a : nc_b;
  728.  
  729.   for (int i = 0; i < len; i++)
  730.     {
  731.       double a_element = a.elem (i, i);
  732.       Complex b_element = b.elem (i, i);
  733.  
  734.       if (a_element == 0.0 || b_element == 0.0)
  735.         c.elem (i, i) = 0.0;
  736.       else if (a_element == 1.0)
  737.         c.elem (i, i) = b_element;
  738.       else if (b_element == 1.0)
  739.         c.elem (i, i) = a_element;
  740.       else
  741.         c.elem (i, i) = a_element * b_element;
  742.     }
  743.  
  744.   return c;
  745. }
  746.  
  747. ComplexDiagMatrix
  748. product (const ComplexDiagMatrix& m, const DiagMatrix& a)
  749. {
  750.   int nr = m.rows ();
  751.   int nc = m.cols ();
  752.  
  753.   int a_nr = a.rows ();
  754.   int a_nc = a.cols ();
  755.  
  756.   if (nr != a_nr || nc != a_nc)
  757.     {
  758.       gripe_nonconformant ("product", nr, nc, a_nr, a_nc);
  759.       return ComplexDiagMatrix ();
  760.     }
  761.  
  762.   if (nr == 0 || nc == 0)
  763.     return ComplexDiagMatrix (nr, nc);
  764.  
  765.   return ComplexDiagMatrix (multiply (m.data (), a.data (), m.length ()),
  766.                 nr, nc);
  767. }
  768.  
  769. ComplexDiagMatrix
  770. product (const DiagMatrix& m, const ComplexDiagMatrix& a)
  771. {
  772.   int nr = m.rows ();
  773.   int nc = m.cols ();
  774.  
  775.   int a_nr = a.rows ();
  776.   int a_nc = a.cols ();
  777.  
  778.   if (nr != a_nr || nc != a_nc)
  779.     {
  780.       gripe_nonconformant ("product", nr, nc, a_nr, a_nc);
  781.       return ComplexDiagMatrix ();
  782.     }
  783.  
  784.   if (nc == 0 || nr == 0)
  785.     return ComplexDiagMatrix (nr, nc);
  786.  
  787.   return ComplexDiagMatrix (multiply (m.data (), a.data (), m.length ()),
  788.                 nr, nc);
  789. }
  790.  
  791. // other operations
  792.  
  793. ComplexColumnVector
  794. ComplexDiagMatrix::diag (void) const
  795. {
  796.   return diag (0);
  797. }
  798.  
  799. // Could be optimized...
  800.  
  801. ComplexColumnVector
  802. ComplexDiagMatrix::diag (int k) const
  803. {
  804.   int nnr = rows ();
  805.   int nnc = cols ();
  806.   if (k > 0)
  807.     nnc -= k;
  808.   else if (k < 0)
  809.     nnr += k;
  810.  
  811.   ComplexColumnVector d;
  812.  
  813.   if (nnr > 0 && nnc > 0)
  814.     {
  815.       int ndiag = (nnr < nnc) ? nnr : nnc;
  816.  
  817.       d.resize (ndiag);
  818.  
  819.       if (k > 0)
  820.     {
  821.       for (int i = 0; i < ndiag; i++)
  822.         d.elem (i) = elem (i, i+k);
  823.     }
  824.       else if ( k < 0)
  825.     {
  826.       for (int i = 0; i < ndiag; i++)
  827.         d.elem (i) = elem (i-k, i);
  828.     }
  829.       else
  830.     {
  831.       for (int i = 0; i < ndiag; i++)
  832.         d.elem (i) = elem (i, i);
  833.     }
  834.     }
  835.   else
  836.     cerr << "diag: requested diagonal out of range\n";
  837.  
  838.   return d;
  839. }
  840.  
  841. // i/o
  842.  
  843. ostream&
  844. operator << (ostream& os, const ComplexDiagMatrix& a)
  845. {
  846.   Complex ZERO (0.0);
  847. //  int field_width = os.precision () + 7;
  848.   for (int i = 0; i < a.rows (); i++)
  849.     {
  850.       for (int j = 0; j < a.cols (); j++)
  851.     {
  852.       if (i == j)
  853.         os << " " /* setw (field_width) */ << a.elem (i, i);
  854.       else
  855.         os << " " /* setw (field_width) */ << ZERO;
  856.     }
  857.       os << "\n";
  858.     }
  859.   return os;
  860. }
  861.  
  862. /*
  863. ;;; Local Variables: ***
  864. ;;; mode: C++ ***
  865. ;;; End: ***
  866. */
  867.